import pickle
import random
import numpy as np
import symmetries
from inspect import getmembers, isfunction
import itertools
from z3 import *

import signal
	
def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    The file handle is closed on return. Only use this on trusted, locally
    generated files: unpickling untrusted data can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


vec_mean = []  # number of symmetry names found for each measure -> reference pair
all_vecs = []  # one list of per-measure feature vectors per solved song
folder = "pickles"

# Every plain function in the `symmetries` module is one symmetry feature;
# inspect.getmembers returns members sorted by name, which fixes each
# feature's index in functions_name_list.
functions_list = [o for o in getmembers(symmetries) if isfunction(o[1])]
functions_name_list = [o[0] for o in functions_list]
num_symmetries = len(functions_name_list)

n_meas = 16  # songs are truncated to this many measures

# NOTE(review): these pickles appear to be produced by gengraph2.py (which
# used to be exec'd here) -- confirm before changing file names.
all_rewards_vec = _load_pickle(folder + "/all_rewards_vec" + str(n_meas) + ".pcl")
songs = _load_pickle(folder + "/meas" + str(n_meas) + ".pcl")
all_tot_rewards = _load_pickle(folder + "/all_rewards_tot" + str(n_meas) + ".pcl")

rewards = {i: 1 for i in functions_name_list}

set_param(timeout=60 * 1000)  # z3 global timeout in milliseconds (60 s)
set_param("parallel.enable", True)

# Per-song results, filled in by the main loop and pickled after each song.
adj_mats = []
simple_adjs = []
prototypes = []
prototype_ids = []
inds = []

def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, closing the file handle afterwards."""
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


def _solve_references(song, tot_rewards):
    """Choose a reference measure for every measure of *song* with z3 Optimize.

    Each measure i gets a reference refs[i] in [0, len(song)). Every chosen
    reference must be flagged "used", and between 3 and
    min(6, len(song) // 2) measures may be flagged used.

    The maximised objective has three terms:
      * sum of tot_rewards[i][refs[i]];
      * minus tot_rewards[i][j] for every ordered pair of distinct used
        references i, j;
      * plus 5 times a pairwise term. For i != j, this term contributes
        2 * (tot_rewards[i][j] - 8) when refs[i] == refs[j], and
        8 - tot_rewards[i][j] otherwise.

    Returns:
        (refs, refs_used): refs[i] is the int reference of measure i, and
        refs_used is the ascending list of measure indices flagged as used.
        Returns None when z3 yields no model (unsat, or failure/timeout).
    """
    n = len(song)
    # 5 bits covers n_meas = 16; widen automatically if longer songs are used.
    width = max(5, n.bit_length())
    o = Optimize()  # maximum weighted set cover
    refs = [BitVec("ref" + str(i), width) for i in range(n)]
    refs_used = [Bool("refs_used" + str(i)) for i in range(n)]
    cost_refs = [Int("costref" + str(i)) for i in range(n)]

    num_used = Sum([If(used, 1, 0) for used in refs_used])
    o.add(num_used >= 3)
    o.add(num_used <= min(6, n // 2))
    for i in range(n):
        ref_i = BV2Int(refs[i])
        o.add(ref_i < n)
        o.add(ref_i >= 0)
        for j in range(n):
            # Picking j as i's reference marks j used and fixes i's cost.
            o.add(Implies(ref_i == j, refs_used[j]))
            o.add(Implies(ref_i == j, cost_refs[i] == tot_rewards[i][j]))

    # Materialise the off-diagonal pairs as a list. A single
    # itertools.product iterator shared by both comprehensions would be
    # exhausted by the first, silently zeroing the grouping term.
    pairs = [(i, j) for i, j in itertools.product(range(n), repeat=2) if i != j]
    overlap_penalty = Sum(
        [If(And(refs_used[i], refs_used[j]), tot_rewards[i][j], 0) for i, j in pairs]
    )
    grouping_term = Sum(
        [
            If(refs[i] == refs[j], 2 * (tot_rewards[i][j] - 8), -1 * tot_rewards[i][j] + 8)
            for i, j in pairs
        ]
    )
    o.maximize(Sum(cost_refs) - overlap_penalty + 5 * grouping_term)

    try:
        # On timeout check() may return `unknown`; model() may then still
        # return the best model found so far, or raise if none is available.
        o.check()
        model = o.model()
    except Z3Exception:
        return None

    chosen = [model.eval(r, model_completion=True).as_long() for r in refs]
    used = [j for j in range(n) if is_true(model.eval(refs_used[j], model_completion=True))]
    return chosen, used


def _build_adjacency(song, refs, refs_used, rewards_vec):
    """Build the feature adjacency tensor and the simple adjacency matrix.

    mat has shape (len(song) + len(refs_used), len(refs_used) + len(song),
    len(rewards) + 1), with these entries set:
      * [i, i - 1, len(rewards)] for i >= 1 (edge to the previous measure);
      * [i, len(song) + k, f] when measure i's reference is refs_used[k] and
        symmetry name functions_name_list[f] appears in rewards_vec[i][refs[i]].

    simple_mat has shape (len(song), len(refs_used)). simple_mat[i, k] is
    True iff refs[i] == refs_used[k].

    Side effect: appends the number of symmetry names found for each measure
    to the global vec_mean.
    """
    n = len(song)
    n_feat = len(rewards)
    feature_index = {name: idx for idx, name in enumerate(functions_name_list)}
    mat = np.zeros((n + len(refs_used), len(refs_used) + n, n_feat + 1), dtype=bool)
    simple_mat = np.zeros((n, len(refs_used)), dtype=bool)

    for i in range(1, n):
        mat[i, i - 1, n_feat] = 1

    for i, j in enumerate(refs):  # j is the reference measure of measure i
        j_ind = refs_used.index(j)
        simple_mat[i, j_ind] = 1
        vec = [feature_index[k] for k in rewards_vec[i][j]]
        vec_mean.append(len(vec))
        mat[i, j_ind + n, vec] = 1
    return mat, simple_mat


def _transformer_vectors(refs, refs_used, mat, rewards_vec, n):
    """Return one feature vector per measure h in [3, n).

    Each vector concatenates four parts:
      * mat[h, n + refs_used.index(refs[h])], the len(rewards) + 1 channels
        on h's reference edge;
      * three 0/1 masks of length num_symmetries, marking which symmetry
        names appear in rewards_vec[h][h - 1], [h - 2] and [h - 3];
      * a one-hot vector of length 9 holding the distance back to the most
        recent earlier measure (at most 8 back) with the same reference.
        Index 0 means there is no such measure.
    """
    vecs = []
    for h in range(3, n):
        ref = refs[h]
        earlier = [i for i in range(h) if refs[i] == ref and h - i <= 8]
        ref_ahead = h - earlier[-1] if earlier else 0
        ref_ahead_np = np.zeros(9)
        ref_ahead_np[ref_ahead] = 1

        prev_syms = [np.zeros(num_symmetries) for _ in range(3)]
        for offset, syms in enumerate(prev_syms):
            neighbor = h - offset - 1
            for k_ind, k in enumerate(functions_name_list):
                if k in rewards_vec[h][neighbor]:
                    syms[k_ind] = 1

        # Same column convention as _build_adjacency (refs_used order, offset n).
        ref_syms = mat[h, n + refs_used.index(ref)]
        vecs.append(np.concatenate([ref_syms, prev_syms[0], prev_syms[1], prev_syms[2], ref_ahead_np]))
    return vecs


for z, full_song in enumerate(songs):
    song = full_song[:n_meas]
    tot_rewards = all_tot_rewards[z]
    rewards_vec = all_rewards_vec[z]

    solution = _solve_references(song, tot_rewards)
    if solution is None:
        print("timeout " + str(z))
        continue
    refs, refs_used = solution

    mat, simple_mat = _build_adjacency(song, refs, refs_used, rewards_vec)
    prototype_ids.append(refs_used)
    prototypes.append([song[q] for q in refs_used])
    adj_mats.append(mat)
    simple_adjs.append(simple_mat)  # True for the reference each measure uses
    inds.append(z)

    # Results are re-dumped after every song (presumably as a checkpoint).
    _dump_pickle(inds, folder + "/inds" + str(n_meas) + ".pcl")
    _dump_pickle(prototypes, folder + "/prototypes" + str(n_meas) + ".pcl")
    _dump_pickle(prototype_ids, folder + "/prototype-ids" + str(n_meas) + ".pcl")
    _dump_pickle(adj_mats, folder + "/adj_mats" + str(n_meas) + ".pcl")
    _dump_pickle(simple_adjs, folder + "/simple_mats" + str(n_meas) + ".pcl")

    all_vecs.append(_transformer_vectors(refs, refs_used, mat, rewards_vec, len(song)))
    _dump_pickle(all_vecs, folder + "/transformervecs.pcl")
